#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Finetuning on classification tasks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six
import sys
if six.PY2:
    reload(sys)
    sys.setdefaultencoding('utf8')
import os
import time
import argparse
import numpy as np
import subprocess
import multiprocessing

import paddle
import paddle.fluid as fluid

import reader.cls as reader
from model import create_model, optimization
from paddlecraft.utils import check_cuda, init_pretraining_params, init_checkpoint, get_cards, Configure


def evaluate(exe, test_program, test_data_loader, fetch_list, eval_phase):
    """Run one full pass over ``test_data_loader`` and print averaged metrics.

    Args:
        exe: ``fluid.Executor`` used to run the program.
        test_program: test-mode ``fluid.Program`` to execute.
        test_data_loader: data loader feeding the program; it is started
            here and reset once the epoch is exhausted.
        fetch_list: variable names to fetch, expected order
            ``[loss, accuracy, num_seqs]``.
        eval_phase: label (e.g. ``"dev"`` / ``"test"``) used in the report.
    """
    test_data_loader.start()
    total_cost, total_acc, total_num_seqs = [], [], []
    time_begin = time.time()
    while True:
        try:
            np_loss, np_acc, np_num_seqs = exe.run(program=test_program,
                                                   fetch_list=fetch_list)
            # Weight per-batch loss/acc by the batch's sequence count so the
            # final averages are per-sequence, not per-batch.
            total_cost.extend(np_loss * np_num_seqs)
            total_acc.extend(np_acc * np_num_seqs)
            total_num_seqs.extend(np_num_seqs)
        except fluid.core.EOFException:
            test_data_loader.reset()
            break
    time_end = time.time()
    # Guard the empty-dataset case: previously np.sum([]) / np.sum([]) would
    # divide 0/0 and print "nan" with a RuntimeWarning.
    num_seqs_sum = np.sum(total_num_seqs)
    if num_seqs_sum == 0:
        print("[%s evaluation] no data, elapsed time: %f s" %
              (eval_phase, time_end - time_begin))
        return
    print("[%s evaluation] ave loss: %f, ave acc: %f, elapsed time: %f s" %
          (eval_phase, np.sum(total_cost) / num_seqs_sum,
           np.sum(total_acc) / num_seqs_sum, time_end - time_begin))


def get_device_num():
    """Return the number of GPU cards visible to this process.

    Prefers the ``CUDA_VISIBLE_DEVICES`` environment variable when it is
    set and non-empty; otherwise counts the GPUs reported by
    ``nvidia-smi -L``.
    """
    # NOTE(zcd): for multi-process training, each process uses one GPU card.
    devices = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    if not devices:
        smi_listing = subprocess.check_output(['nvidia-smi', '-L']).decode()
        return smi_listing.count('\n')
    return len(devices.split(','))


def main():
    """Fine-tune and/or evaluate a BERT classifier driven by a JSON config.

    Reads every hyper-parameter from a hard-coded bert_config.json, builds
    the train/dev/test programs requested by the ``do_train``/``do_val``/
    ``do_test`` flags, restores a checkpoint or pretrained parameters,
    runs the training loop (with periodic logging, checkpointing, and
    validation), and finally evaluates on the dev/test sets.
    """
    bert_config = Configure(
        json_file="./data/pretrained_models/uncased_L-12_H-768_A-12/bert_config.json"
    )
    bert_config.build()
    print(bert_config.incr_ratio)
    bert_config.Print()

    # Device selection: one GPU per process (selected via FLAGS_selected_gpus),
    # or CPU with CPU_NUM parallel instances.
    if bert_config.use_cuda:
        place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0')))
        dev_count = get_device_num()
    else:
        place = fluid.CPUPlace()
        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
    exe = fluid.Executor(place)

    # Map task name -> dataset reader; KeyError here means an unknown task.
    task_name = bert_config.task_name.lower()
    processors = {
        'xnli': reader.XnliProcessor,
        'cola': reader.ColaProcessor,
        'mrpc': reader.MrpcProcessor,
        'mnli': reader.MnliProcessor,
    }

    processor = processors[task_name](data_dir=bert_config.data_dir,
                                      vocab_path=bert_config.vocab_path,
                                      max_seq_len=bert_config.max_seq_len,
                                      do_lower_case=bert_config.do_lower_case,
                                      in_tokens=bert_config.in_tokens,
                                      random_seed=bert_config.random_seed)
    num_labels = len(processor.get_labels())

    if not (bert_config.do_train or bert_config.do_val or bert_config.do_test):
        raise ValueError("For configure `do_train`, `do_val` and `do_test`, at "
                         "least one of them must be True.")
    train_program = fluid.Program()
    startup_prog = fluid.Program()
    # Seed both programs so parameter init and dropout are reproducible.
    if bert_config.random_seed is not None:
        startup_prog.random_seed = bert_config.random_seed
        train_program.random_seed = bert_config.random_seed

    if bert_config.do_train:
        # NOTE: If num_trainers > 1, the shuffle_seed must be set, because
        # the order of batch data generated by reader
        # must be the same in the respective processes.
        shuffle_seed = None
        train_data_generator = processor.data_generator(
            batch_size=bert_config.batch_size,
            phase='train',
            epoch=bert_config.epoch,
            dev_count=dev_count,
            shuffle=bert_config.shuffle,
            shuffle_seed=shuffle_seed)

        num_train_examples = processor.get_num_examples(phase='train')

        # With in_tokens, batch_size is a token budget, so the effective
        # examples-per-batch is batch_size // max_seq_len.
        if bert_config.in_tokens:
            max_train_steps = bert_config.epoch * num_train_examples // (
                bert_config.batch_size // bert_config.max_seq_len) // dev_count
        else:
            max_train_steps = bert_config.epoch * num_train_examples // bert_config.batch_size // dev_count

        warmup_steps = int(max_train_steps * bert_config.warmup_proportion)
        print("Device count: %d" % dev_count)
        print("Num train examples: %d" % num_train_examples)
        print("Max train steps: %d" % max_train_steps)
        print("Num warmup steps: %d" % warmup_steps)

        # Build the training graph; unique_name.guard keeps parameter names
        # consistent across the train/dev/test programs so weights are shared.
        with fluid.program_guard(train_program, startup_prog):
            with fluid.unique_name.guard():
                train_data_loader, loss, probs, accuracy, num_seqs = create_model(
                    bert_config=bert_config, num_labels=num_labels)
                scheduled_lr, loss_scaling = optimization(
                    loss=loss,
                    warmup_steps=warmup_steps,
                    num_train_steps=max_train_steps,
                    learning_rate=bert_config.learning_rate,
                    train_program=train_program,
                    startup_prog=startup_prog,
                    weight_decay=bert_config.weight_decay,
                    scheduler=bert_config.lr_scheduler,
                    use_fp16=False,
                    use_dynamic_loss_scaling=bert_config.
                    use_dynamic_loss_scaling,
                    init_loss_scaling=bert_config.init_loss_scaling,
                    incr_every_n_steps=bert_config.incr_every_n_steps,
                    decr_every_n_nan_or_inf=bert_config.decr_every_n_nan_or_inf,
                    incr_ratio=bert_config.incr_ratio,
                    decr_ratio=bert_config.decr_ratio)

    if bert_config.do_val:
        # Dev program: same model graph, cloned for inference (no backward ops).
        dev_prog = fluid.Program()
        with fluid.program_guard(dev_prog, startup_prog):
            with fluid.unique_name.guard():
                dev_data_loader, loss, probs, accuracy, num_seqs = create_model(
                    bert_config=bert_config, num_labels=num_labels)

        dev_prog = dev_prog.clone(for_test=True)
        dev_data_loader.set_batch_generator(
            processor.data_generator(
                batch_size=bert_config.batch_size,
                phase='dev',
                epoch=1,
                dev_count=1,
                shuffle=False),
            place)

    if bert_config.do_test:
        # Test program: built like the dev program, fed from the 'test' split.
        test_prog = fluid.Program()
        with fluid.program_guard(test_prog, startup_prog):
            with fluid.unique_name.guard():
                test_data_loader, loss, probs, accuracy, num_seqs = create_model(
                    bert_config=bert_config, num_labels=num_labels)

        test_prog = test_prog.clone(for_test=True)
        test_data_loader.set_batch_generator(
            processor.data_generator(
                batch_size=bert_config.batch_size,
                phase='test',
                epoch=1,
                dev_count=1,
                shuffle=False),
            place)

    # Initialize all parameters before any checkpoint/pretrained restore.
    exe.run(startup_prog)

    if bert_config.do_train:
        # init_checkpoint takes precedence when both restore options are set.
        if bert_config.init_checkpoint and bert_config.init_pretraining_params:
            print(
                "WARNING: config 'init_checkpoint' and 'init_pretraining_params' "
                "both are set! Only arg 'init_checkpoint' is made valid.")
        if bert_config.init_checkpoint:
            init_checkpoint(
                exe,
                bert_config.init_checkpoint,
                main_program=startup_prog,
                use_fp16=False)
        elif bert_config.init_pretraining_params:
            init_pretraining_params(
                exe,
                bert_config.init_pretraining_params,
                main_program=startup_prog,
                use_fp16=False)

    elif bert_config.do_val or bert_config.do_test:
        # Eval-only runs must restore trained weights or results are random.
        if not bert_config.init_checkpoint:
            raise ValueError("config 'init_checkpoint' should be set if"
                             "only doing validation or testing!")
        init_checkpoint(
            exe,
            bert_config.init_checkpoint,
            main_program=startup_prog,
            use_fp16=False)

    if bert_config.do_train:
        # Compile the train program for multi-card data parallelism.
        exec_strategy = fluid.ExecutionStrategy()
        exec_strategy.use_experimental_executor = bert_config.use_fast_executor
        exec_strategy.num_threads = dev_count
        exec_strategy.num_iteration_per_drop_scope = bert_config.num_iteration_per_drop_scope
        build_strategy = fluid.BuildStrategy()

        train_compiled_program = fluid.CompiledProgram(
            train_program).with_data_parallel(
                loss_name=loss.name, build_strategy=build_strategy)

        train_data_loader.set_batch_generator(train_data_generator, place)

    if bert_config.do_train:
        train_data_loader.start()
        steps = 0
        total_cost, total_acc, total_num_seqs = [], [], []
        time_begin = time.time()
        throughput = []
        ce_info = []

        total_batch_num = 0  # used for benchmark

        # Training loop: runs until the data loader raises EOFException
        # (end of all epochs) or max_iter batches are consumed.
        while True:
            try:
                steps += 1

                total_batch_num += 1  # used for benchmark
                if bert_config.max_iter and total_batch_num == bert_config.max_iter:  # used for benchmark
                    return

                # Only fetch metrics every skip_steps steps; fetching forces a
                # device sync, so an empty fetch_list keeps other steps fast.
                if steps % bert_config.skip_steps == 0:
                    fetch_list = [
                        loss.name, accuracy.name, scheduled_lr.name,
                        num_seqs.name
                    ]
                else:
                    fetch_list = []

                outputs = exe.run(train_compiled_program, fetch_list=fetch_list)

                if steps % bert_config.skip_steps == 0:
                    np_loss, np_acc, np_lr, np_num_seqs = outputs

                    # Weight per-batch metrics by batch size so the logged
                    # averages are per-sequence.
                    total_cost.extend(np_loss * np_num_seqs)
                    total_acc.extend(np_acc * np_num_seqs)
                    total_num_seqs.extend(np_num_seqs)

                    if bert_config.verbose:
                        verbose = "train data_loader queue size: %d, " % train_data_loader.queue.size(
                        )
                        verbose += "learning rate: %f" % np_lr[0]
                        print(verbose)

                    current_example, current_epoch = processor.get_train_progress(
                    )
                    time_end = time.time()
                    used_time = time_end - time_begin

                    log_record = "epoch: {}, progress: {}/{}, step: {}, ave loss: {}, ave acc: {}".format(
                        current_epoch, current_example, num_train_examples,
                        steps,
                        np.sum(total_cost) / np.sum(total_num_seqs),
                        np.sum(total_acc) / np.sum(total_num_seqs))
                    ce_info.append([
                        np.sum(total_cost) / np.sum(total_num_seqs),
                        np.sum(total_acc) / np.sum(total_num_seqs), used_time
                    ])
                    # NOTE(review): steps is incremented at the top of the loop,
                    # so it is always >= 1 here and the else branch below looks
                    # unreachable — confirm before relying on it.
                    if steps > 0:
                        throughput.append(bert_config.skip_steps / used_time)
                        log_record = log_record + ", speed: %f steps/s" % (
                            bert_config.skip_steps / used_time)
                        print(log_record)
                    else:
                        print(log_record)
                    # Reset window accumulators so each log covers only the
                    # last skip_steps steps.
                    total_cost, total_acc, total_num_seqs = [], [], []
                    time_begin = time.time()

                if steps % bert_config.save_steps == 0:
                    save_path = os.path.join(bert_config.checkpoints,
                                             "step_" + str(steps))
                    fluid.io.save_persistables(exe, save_path, train_program)

                if steps % bert_config.validation_steps == 0:
                    print("Average throughtput: %s" % (np.average(throughput)))
                    throughput = []
                    # evaluate dev set
                    if bert_config.do_val:
                        evaluate(exe, dev_prog, dev_data_loader,
                                 [loss.name, accuracy.name, num_seqs.name],
                                 "dev")
                    # evaluate test set
                    if bert_config.do_test:
                        evaluate(exe, test_prog, test_data_loader,
                                 [loss.name, accuracy.name, num_seqs.name],
                                 "test")
            except fluid.core.EOFException:
                # End of training data: save a final checkpoint and stop.
                save_path = os.path.join(bert_config.checkpoints,
                                         "step_" + str(steps))
                fluid.io.save_persistables(exe, save_path, train_program)
                train_data_loader.reset()
                break

    # final eval on dev set
    if bert_config.do_val:
        print("Final validation result:")
        evaluate(exe, dev_prog, dev_data_loader,
                 [loss.name, accuracy.name, num_seqs.name], "dev")

    # final eval on test set
    if bert_config.do_test:
        print("Final test result:")
        evaluate(exe, test_prog, test_data_loader,
                 [loss.name, accuracy.name, num_seqs.name], "test")


if __name__ == '__main__':
    # Script entry point: run the full fine-tuning / evaluation pipeline.
    main()
